Table of Contents

  • 1  Forecast of COVID-19 cases
  • 2  Risk assessment of COVID-19 cases
    • 2.1  Data
      • 2.1.1  Summary Statistics
    • 2.2  Method: Regression Tree
      • 2.2.1  Paper Model
        • 2.2.1.1  Different representation of regression tree
        • 2.2.1.2  Variable Importance Plot
      • 2.2.2  Add 'Total deaths' and 'Climate Zones' variable to the tree
      • 2.2.3  RT Europe
    • 2.3  Metrics


Forecast of COVID-19 cases¶

In [9]:
import pandas as pd
import matplotlib.pyplot as plt

arima_data = pd.read_excel("../data/Real_time_forecast_dataset_04_04_20.xlsx")

# Align the series: shift each column so that every series ends on 4 April 2020
length = arima_data.count()
for k, i in enumerate(length):
    arima_data[arima_data.columns[k]] = arima_data[arima_data.columns[k]].shift(periods=len(arima_data) - i)
arima_data['date'] = pd.date_range(end='4/4/2020', periods=len(arima_data), freq='D')

arima_data = arima_data.set_index('date')
arima_data.plot(
    kind='line', stacked=True,
    figsize=(15, 6)
).set_title('Forecast of COVID-19 cases',
            fontfamily='Tahoma',
            fontsize='x-large',
            fontstyle='italic',
            fontweight='extra bold',
            fontvariant='small-caps');

Risk assessment of COVID-19 cases¶

At the outset of the COVID-19 outbreak, data on country-wise case fatality rates due to COVID-19 were obtained for 50 affected countries. The case fatality rate (CFR) can be crudely defined as the number of deaths in persons who tested positive for COVID-19 divided by the confirmed number of COVID-19 cases.

In this section, we identify a set of essential causal variables that strongly influence the CFR. The datasets and code for this section are publicly available at https://github.com/indrajitg-r/COVID for the reproducibility of this work.
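The crude definition above amounts to a single division. As a sanity check, a minimal sketch using Italy's figures from the dataset (14,681 deaths among 119,827 confirmed cases as of April 4, 2020):

```python
# Crude case fatality rate: deaths among confirmed cases / confirmed cases
def case_fatality_rate(deaths, confirmed_cases):
    if confirmed_cases <= 0:
        raise ValueError("confirmed_cases must be positive")
    return deaths / confirmed_cases

# Italy's figures from the dataset: 14,681 deaths, 119,827 confirmed cases
print(round(case_fatality_rate(14681, 119827), 3))  # 0.123
```

This matches the CFR value reported for Italy in the table below.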

Data¶

The CFR modeling dataset consists of 50 observations having ten possible causal variables and one numerical output variable.

The possible causal variables considered in this study are the following:

  • the total number of COVID-19 cases (in thousands) in the country till 4 April, 2020,
  • population density per km2 for the country,
  • total population (in millions) of the country (approx.),
  • percentage of people in the age group of greater than 65 years,
  • lockdown days count (from the starting day of lockdown till April 4, 2020),
  • time-period (in days) of COVID-19 cases for the country (starting date to April 4, 2020),
  • doctors per 1000 people in the country,
  • hospital beds per 1000 people in the country,
  • income standard (e.g., high or low) of the country,
  • climate zones (e.g., tropical, subtropical or moderate) of the country.

The dataset thus contains eight numerical input variables and two categorical input variables.
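Tree models in scikit-learn require numeric inputs, so the two categorical variables have to be encoded. The integer codes visible in the table below (0/1 for income class, -1/0/1 for climate zones) suggest a simple label mapping; the sketch below is an assumed reconstruction, not the authors' preprocessing code:

```python
import pandas as pd

# Assumed label mappings (not taken from the paper's code); they mirror the
# integer codes that appear in the CFR dataset table.
income_map = {"lower": 0, "high": 1}
climate_map = {"tropical": -1, "subtropical": 0, "moderate": 1}

sample = pd.DataFrame({
    "Income class": ["high", "lower", "high"],
    "Climate zones": ["subtropical", "tropical", "moderate"],
})
encoded = sample.assign(**{
    "Income class": sample["Income class"].map(income_map),
    "Climate zones": sample["Climate zones"].map(climate_map),
})
print(encoded["Income class"].tolist())   # [1, 0, 1]
print(encoded["Climate zones"].tolist())  # [0, -1, 1]
```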

In [11]:
# Importing the necessary libraries
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import HTML

# Download from kaggle the dataset containing the flags URL
import zipfile
zf = zipfile.ZipFile("../data/archive.zip")
zf.infolist()
flags = pd.read_csv(zf.open('countries_continents_codes_flags_url.csv'), usecols=[0,2])
flags.columns = flags.columns.str.title()

# Harmonize country names so the two datasets can be merged on 'Country'
flags.replace('South Korea', 'S. Korea', inplace=True)
flags.replace('United States', 'USA', inplace=True)
flags.replace('United Kingdom', 'UK', inplace=True)
flags.replace('Czech Republic', 'Czechia', inplace=True)

flags = pd.merge(covid_data, flags, 'left', on='Country').set_index('Country')
flags.columns = flags.columns.str.title()
flags = flags.sort_values(by=['Cases In Thousands'],ascending= False)

# Converting links to html tags
def path_to_image_html(path):
    return '<img src="'+ path + '" width="50" >'
# Render the HTML produced by to_html, with flag images embedded via path_to_image_html
#HTML(flags.to_html(escape=False,formatters=dict(Image_Url=path_to_image_html)))

cols = flags.columns.tolist()
cols = cols[-1:] + cols[:-1]
flags = flags[cols] 

# Display a clear, readable table
flags =(np.round(flags,decimals=3).rename(columns={'% People (>65)': 'People (>65)',
                                                      'Population Density/Km2': 'Population Density',
                                                      'Cases In Thousands': 'Cases', 
                                                      'Population (In Millions)':'Population',
                                                  'Image_Url': 'Flag'})
).style.format(formatter={('People (>65)'): lambda x: "{:,.1f}%".format(x),
                          ('Population Density'): lambda x: "{:,} /$km^2$".format(x),
                          ('Cases'): lambda x: '{:.3f}$K$'.format(x),
                          ('Population'): lambda x: '{:.2f}$M$'.format(x),
                          ('Doctors Per 1000 People'): lambda x: '{:.2f}'.format(x),
                          ('Hospital Beds Per 1000'): lambda x: '{:.2f}'.format(x),
                          ('Flag'): lambda x: path_to_image_html(x)})

styles = [
    dict(selector="tr:hover",
         props=[("background-color", "%s" % "#ffff99")]),
    dict(selector="th", props=[("font-size", "110%"),
                               ("text-align", "center")]),
    dict(selector="caption", props=[("caption-side", "top"),
                                    ("font-size", "150%"),
                                    ("text-align", "center")])]

flags = (flags.set_table_styles(styles)
          .background_gradient(cmap= sns.light_palette("red", as_cmap=True),
                               subset=['Cases'])
          .set_caption("CFR Dataset"))

flags
Out[11]:
CFR Dataset
  Flag Cases Population Population Density People (>65) No. Of Days Since Shutdown Time Of Arival (Till Today) Doctors Per 1000 People Hospital Beds Per 1000 Income Class Climate Zones Cfr Total Deaths
Country                          
USA 277.965$K$ 329.55$M$ 34 /$km^2$ 15.4% 16 75 2.57 2.90 1 1 0.026000 7157
Italy 119.827$K$ 60.25$M$ 200 /$km^2$ 23.0% 26 65 4.02 3.40 1 1 0.123000 14681
Spain 117.710$K$ 46.93$M$ 93 /$km^2$ 19.4% 20 64 3.87 3.10 1 0 0.093000 10935
China 82.527$K$ 1402.01$M$ 145 /$km^2$ 10.6% 73 84 1.81 3.80 0 0 0.040000 3330
Germany 79.696$K$ 83.15$M$ 233 /$km^2$ 21.5% 21 68 4.19 8.20 1 1 0.013000 1017
France 64.338$K$ 67.06$M$ 123 /$km^2$ 19.7% 22 71 3.24 6.40 1 1 0.101000 6507
Iran 50.468$K$ 83.33$M$ 51 /$km^2$ 5.4% 46 76 1.49 0.10 0 0 0.063000 3160
UK 38.168$K$ 66.43$M$ 274 /$km^2$ 18.5% 19 64 2.83 2.90 1 1 0.094000 3605
Turkey 20.921$K$ 83.15$M$ 106 /$km^2$ 8.2% 19 25 1.75 2.50 0 0 0.020000 425
Switzerland 19.706$K$ 8.58$M$ 208 /$km^2$ 18.4% 20 39 4.25 5.00 1 1 0.031000 607
Belgium 16.770$K$ 11.52$M$ 376 /$km^2$ 18.6% 18 60 3.01 6.50 1 1 0.068000 1143
Netherlands 15.723$K$ 17.45$M$ 420 /$km^2$ 18.8% 17 37 3.48 4.70 1 1 0.095000 1487
Canada 12.519$K$ 31.98$M$ 4 /$km^2$ 17.0% 20 70 2.54 2.70 1 1 0.015000 187
Austria 11.525$K$ 8.90$M$ 106 /$km^2$ 19.2% 20 39 5.23 7.60 1 1 0.015000 168
S. Korea 10.156$K$ 51.78$M$ 517 /$km^2$ 13.9% 13 75 2.33 10.30 1 0 0.017000 177
Portugal 9.886$K$ 10.28$M$ 112 /$km^2$ 21.5% 17 62 4.43 3.40 1 0 0.025000 246
Brazil 9.056$K$ 211.33$M$ 25 /$km^2$ 8.6% 19 39 1.85 2.30 0 0 0.040000 359
Israel 7.428$K$ 9.18$M$ 416 /$km^2$ 11.7% 13 43 3.58 3.30 1 -1 0.005000 39
Sweden 6.078$K$ 10.37$M$ 23 /$km^2$ 19.9% 11 64 4.19 2.70 1 1 0.055000 333
Australia 5.548$K$ 25.66$M$ 3 /$km^2$ 15.5% 13 70 3.50 3.90 1 -1 0.005000 30
Norway 5.208$K$ 5.37$M$ 17 /$km^2$ 16.8% 20 38 4.38 3.30 1 1 0.008000 44
Ireland 4.273$K$ 4.92$M$ 70 /$km^2$ 13.9% 24 35 2.96 2.90 1 1 0.028000 120
Czechia 4.190$K$ 10.68$M$ 135 /$km^2$ 19.0% 24 33 3.68 6.80 1 1 0.013000 53
Russia 4.149$K$ 146.88$M$ 9 /$km^2$ 14.2% 22 64 3.98 9.70 0 1 0.008000 34
Denmark 3.757$K$ 5.81$M$ 135 /$km^2$ 19.7% 22 37 3.65 3.50 1 1 0.037000 139
Poland 3.383$K$ 38.39$M$ 123 /$km^2$ 16.8% 21 31 2.29 6.50 1 1 0.021000 71
Ecuador 3.368$K$ 17.46$M$ 63 /$km^2$ 7.1% 20 35 1.67 1.60 0 1 0.043000 145
Malaysia 3.333$K$ 32.74$M$ 99 /$km^2$ 6.3% 20 70 1.53 1.90 0 0 0.016000 53
Romania 3.183$K$ 19.40$M$ 81 /$km^2$ 17.9% 15 38 2.67 6.10 0 1 0.042000 133
Philippines 3.018$K$ 108.48$M$ 362 /$km^2$ 4.8% 22 65 1.11 1.00 0 0 0.045000 136
Japan 2.935$K$ 126.01$M$ 333 /$km^2$ 27.0% 34 80 2.37 13.70 1 0 0.024000 69
India 2.902$K$ 1360.49$M$ 414 /$km^2$ 6.0% 13 65 0.76 0.70 0 0 0.023000 68
Luxembourg 2.612$K$ 0.61$M$ 237 /$km^2$ 14.3% 0 35 2.92 5.40 1 1 0.012000 31
Pakistan 2.291$K$ 219.14$M$ 273 /$km^2$ 4.5% 15 38 0.98 0.60 0 0 0.014000 31
Indonesia 1.986$K$ 268.07$M$ 141 /$km^2$ 5.3% 21 33 0.20 0.90 0 0 0.091000 181
Mexico 1.688$K$ 126.58$M$ 64 /$km^2$ 6.9% 16 35 2.23 1.50 0 -1 0.036000 60
Panama 1.673$K$ 4.16$M$ 56 /$km^2$ 7.9% 14 25 1.59 2.20 1 0 0.025000 41
Finland 1.615$K$ 5.52$M$ 16 /$km^2$ 21.2% 13 66 3.20 5.50 1 1 0.012000 20
Greece 1.613$K$ 10.72$M$ 81 /$km^2$ 20.4% 21 38 6.25 4.80 1 0 0.037000 59
Peru 1.595$K$ 32.16$M$ 25 /$km^2$ 7.2% 20 29 1.12 1.50 0 0 0.038000 61
Dominican Republic 1.488$K$ 10.36$M$ 216 /$km^2$ 7.0% 20 34 1.49 1.70 0 0 0.046000 68
Serbia 1.476$K$ 6.90$M$ 89 /$km^2$ 17.4% 17 29 2.46 5.40 0 1 0.026000 39
Colombia 1.267$K$ 46.22$M$ 40 /$km^2$ 7.6% 20 28 1.82 1.50 0 0 0.020000 25
Argentina 1.265$K$ 44.94$M$ 16 /$km^2$ 11.2% 20 32 3.21 4.70 0 -1 0.029000 37
Ukraine 0.987$K$ 41.90$M$ 69 /$km^2$ 16.5% 23 32 3.00 9.00 0 1 0.023000 23
Algeria 0.986$K$ 43.00$M$ 18 /$km^2$ 6.2% 17 39 1.21 1.70 0 -1 0.084000 83
Egypt 0.779$K$ 100.19$M$ 100 /$km^2$ 5.2% 17 50 0.81 0.50 0 -1 0.067000 52
Iraq 0.772$K$ 39.31$M$ 90 /$km^2$ 3.2% 19 42 0.85 1.30 0 -1 0.070000 54
Morocco 0.708$K$ 35.86$M$ 80 /$km^2$ 6.8% 22 33 0.62 0.90 0 -1 0.062000 44
San Marino 0.251$K$ 0.03$M$ 568 /$km^2$ 16.0% 15 37 6.36 3.80 1 1 0.127000 32

Summary Statistics¶

In [12]:
covid_data.describe().round(2)
Out[12]:
cases in thousands population (in millions) population density/km2 % People (>65) no. of days since shutdown time of arival (till today) Doctors per 1000 people Hospital beds per 1000 CFR Total deaths
count 50.00 50.00 50.00 50.00 50.00 50.00 50.00 50.00 50.00 50.00
mean 20.90 110.62 149.78 13.58 20.20 48.72 2.71 3.93 0.04 1151.98
std 46.78 271.40 142.73 6.21 9.80 17.58 1.41 2.87 0.03 2864.53
min 0.25 0.03 3.00 3.20 0.00 25.00 0.20 0.10 0.01 20.00
25% 1.63 10.36 52.25 7.12 16.25 35.00 1.61 1.70 0.02 44.00
50% 3.57 37.12 99.50 14.85 20.00 39.00 2.62 3.30 0.03 77.00
75% 12.27 83.15 214.00 18.75 21.00 65.00 3.64 5.40 0.06 352.50
max 277.96 1402.01 568.00 27.00 73.00 84.00 6.36 13.70 0.13 14681.00

Method: Regression Tree¶

For the risk assessment with the CFR dataset for 50 countries, we apply the Regression Tree (RT), a non-parametric supervised learning method used for regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features.

The corresponding machine learning algorithm is Classification and Regression Trees (CART).

The basic idea behind the algorithm is to find the point in an independent variable at which to split the dataset into two parts so that the mean squared error is minimized. In other words, for a regression task it takes a feature and determines which cut-off point minimizes the variance of $y$, since the variance tells us how much the $y$ values in a node are spread around their mean value $\bar{y}$. As a consequence, the best cut-off point makes the two resulting subsets as different as possible with respect to the target outcome.

The algorithm continues this search-and-split recursively, creating different subsets of the dataset until a stopping criterion is reached. Possible criteria are a minimum number of instances that have to be in a node before the split, or a minimum number of instances that have to be in a terminal node.

The intermediate subsets are called internal nodes (or split nodes) and the final subsets are called terminal (or leaf nodes). To predict the outcome in each leaf node, the average outcome of the training data in this node is used.
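The search-and-split step described above can be sketched in a few lines. This is a toy, brute-force version of what CART does for a single feature, not sklearn's optimized implementation:

```python
import numpy as np

# Scan every candidate cut-off on one feature and keep the one that minimizes
# the weighted variance of y in the two resulting subsets.
def best_split(x, y):
    order = np.argsort(x)
    x, y = x[order], y[order]
    best_cut, best_score = None, np.inf
    for i in range(1, len(x)):
        left, right = y[:i], y[i:]
        # Weighted within-node variance after the split
        score = (len(left) * left.var() + len(right) * right.var()) / len(y)
        if score < best_score:
            best_cut, best_score = (x[i - 1] + x[i]) / 2, score
    return best_cut, best_score

# Illustrative data: two clusters of y values along x
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([0.1, 0.2, 0.1, 1.0, 1.1, 0.9])
cut, score = best_split(x, y)
print(cut)  # 6.5 — separates the low-y cluster from the high-y cluster
```

Each resulting subset is then split again recursively, and the prediction in a leaf is the mean of the training $y$ values that fall into it.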

Paper Model¶

Decision trees can also be applied to regression problems, using the DecisionTreeRegressor class.

Basic regression trees partition a data set into smaller subgroups and then fit a simple constant for each observation in the subgroup. The partitioning is achieved by successive binary partitions (aka recursive partitioning) based on the different predictors. The constant to predict is based on the average response values for all observations that fall in that subgroup.

  • In rpart, the default priors are proportional to the data counts
  • The default losses are 1
  • For classification the split criterion defaults to Gini; for a regression tree such as ours, splits minimize the mean squared error (rpart's method='anova')
Considerations¶

After running the model proposed by the authors, the resulting regression tree (RT) did not correspond to the one published in the paper. Looking at the variable importance graph, we noticed that nine variables were considered instead of the ten initially mentioned by the authors.

Therefore, we have also excluded the variable 'Climate Zones' ($x_{10}$).

In [13]:
from sklearn.tree import DecisionTreeRegressor
from sklearn import tree
In [14]:
#We exclude Total Deaths and Climate Zones as in the paper
X = covid_data.drop(columns=['CFR', 'Total deaths', 'Climate zones'])
y = covid_data['CFR']

# We set the parameters to match the control parameters of the corresponding R function 'rpart' used in the paper
model = tree.DecisionTreeRegressor(criterion= "mse", # $method='anova'
                                   min_samples_split = 5, # $minsplit = 5
                                   max_depth=30, # $maxdepth
                                   min_samples_leaf=2) #$minbucket
model.fit(X,y);
In [50]:
plt.figure(figsize=(100,70))
features = X.columns.str.title()
tree.plot_tree(model,fontsize=40, feature_names=features,
               filled=True, node_ids=False, rounded=True)
plt.show()

Different representation of regression tree¶

In [25]:
# In case I was interested in the prediction of a single observation:
# observed = np.array(covid_data)[np.random.randint(0, len(covid_data)),:]
from dtreeviz.trees import dtreeviz
viz = dtreeviz(model, X, y,
               target_name="CFR",
               feature_names= features,
               title='Regression Tree',
               scale=0.8, orientation="LR",
               show_node_labels = False,
               #X=observed
               colors={'title':'black',
                       'text':'#14213d',
                       'arrow':'#455e89',
                       'scatter_marker':'#a01a58',
                       'tick_label':'grey',
                       'split_line':'#CED4DA'})
viz
Out[25]:
[dtreeviz SVG output: the regression tree titled 'Regression Tree', rendered left-to-right with a plot at each split node and leaf]

Variable Importance Plot¶

In [48]:
(pd.Series(model.feature_importances_,
           index= X.columns.str.title())
   .nsmallest(10) # Sort ascending: barh then places the most important variables at the top
   .plot(kind='barh',
         title = 'Variable Importance',
         figsize = [8,5],
         table = False,
         fontsize = 13,
         color = '#2e6f95',
         align='edge', width=0.8
         ));

Add 'Total deaths' and 'Climate Zones' variables to the tree¶

In [47]:
X_2 = covid_data.drop(columns=['CFR'])
# We set the parameters to match the control parameters of the corresponding R function 'rpart' used in the paper
model2 = tree.DecisionTreeRegressor(criterion= "mse", # $method='anova'
                                   min_samples_split = 5, # $minsplit = 5
                                   max_depth=30, # $maxdepth
                                   min_samples_leaf=2) #$minbucket
model2.fit(X_2,y)

#plot the regression tree
plt.figure(figsize=(80,50))
features = X_2.columns
tree.plot_tree(model2,
               feature_names=features,
               filled=True,
               fontsize=40)
plt.show()

RT Europe¶

In [20]:
world = pd.read_csv("https://raw.githubusercontent.com/dbouquin/IS_608/master/NanosatDB_munging/Countries-Continents.csv")
world.replace('US', 'USA', inplace=True)
world.replace('United Kingdom', 'UK', inplace=True)

final = pd.merge(covid_data, world, 'left', on='Country')

final1 = final.groupby('Continent')
Africa = final1.get_group('Africa') #3 observations
Asia = final1.get_group('Asia') #11 observations
Europe = final1.get_group('Europe') #22 observation
N_America = final1.get_group('North America') #3 observation
S_America = final1.get_group('South America') #5 observation
Oceania = final1.get_group('Oceania') #1 observation
In [44]:
X3 = Europe.drop(columns=['CFR', 'Continent', 'Country', 'Total deaths', 'Climate zones'])
y3 = Europe['CFR']

# minsplit set to 10% of the observations: 10% of the 22 European rows ≈ 2
model3 = tree.DecisionTreeRegressor(criterion= "mse", # $method='anova'
                                    min_samples_split = 2, # $minsplit = 2
                                    max_depth=30, # $maxdepth
                                    min_samples_leaf=2) #$minbucket
model3.fit(X3,y3)

plt.figure(figsize=(80,50))
features = X3.columns
tree.plot_tree(model3,
               feature_names=features,
               filled=True,
               fontsize=50)
plt.show()
In [22]:
(pd.Series(model3.feature_importances_,
           index= X3.columns.str.title())
   .nsmallest(10) # Sort ascending: barh then places the most important variables at the top
   .plot(kind='barh',
         title = 'Variable Importance',
         figsize = [8,5],
         table = False,
         fontsize = 13,
         color = '#4E8d95',
         align='edge', width=0.8));

Metrics¶

When assessing how well a model fits a dataset, we use the Root Mean Squared Error (RMSE). The RMSE is a metric computed as the square root of the average squared difference between the predicted values and the actual values in a dataset:

$$RMSE=\sqrt{\frac{\sum_{i=1}^{n}(\hat{y}_i-y_i)^2}{n}}$$

where:

  • $\hat{y}_i$ is the predicted value for the $i$th observation
  • $y_i$ is the observed value for the $i$th observation
  • $n$ is the sample size

The Mean Absolute Error (MAE) is a measure of the errors between paired observations expressing the same phenomenon:

$$MAE = \frac{\sum_{i=1}^{n}|\hat{y}_i-y_i|}{n}$$

$MAE$ is conceptually simpler and easier to interpret than $RMSE$: it is simply the average absolute difference between $\hat{y}_i$ and $y_i$, i.e., the average vertical distance between each point and the identity line in a predicted-versus-observed scatter plot.

Furthermore, each error contributes to $MAE$ in proportion to its absolute value. This is in contrast to $RMSE$, which squares the differences, so a few large errors increase $RMSE$ to a greater degree than $MAE$.
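Both formulas can be checked directly on a toy example (the arrays below are illustrative, not the model's predictions); note how the single large error dominates $RMSE$ but is averaged away in $MAE$:

```python
import numpy as np

# Compute RMSE and MAE from their definitions on made-up CFR-sized values
y_true = np.array([0.02, 0.04, 0.03, 0.05])
y_hat  = np.array([0.02, 0.04, 0.03, 0.13])  # one large error of 0.08

errors = y_hat - y_true
rmse = np.sqrt(np.mean(errors ** 2))  # squaring amplifies the large error
mae = np.mean(np.abs(errors))         # each error counts proportionally
print(round(rmse, 3), round(mae, 3))  # 0.04 0.02
```

With three perfect predictions and one error of 0.08, $RMSE$ is twice $MAE$, illustrating the sensitivity to outliers discussed above.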

In [23]:
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

y_pred = model.predict(X)

RMSE = np.round(mean_squared_error(y, y_pred, squared=False), 4) 
MAE = np.round(mean_absolute_error(y, y_pred),4)

r2 = np.round(r2_score(y, y_pred),2)
Adj_r2 = np.round(1 - (1-r2_score(y, y_pred)) * (len(y)-1)/(len(y)-X.shape[1]-1),2)

y_pred2 = model2.predict(X_2)

RMSE_2 = round(mean_squared_error(y, y_pred2, squared=False), 4)
MAE_2 = round(mean_absolute_error(y, y_pred2),4)

r2_2 = round(r2_score(y, y_pred2),2)
Adj_r2_2 = round(1 - (1-r2_score(y, y_pred2)) * (len(y)-1)/(len(y)-X_2.shape[1]-1),2)

(pd.DataFrame(
    {'RMSE':[RMSE, RMSE_2],
     'MAE' :[MAE, MAE_2],
     'R^2': [r2, r2_2],
     'Adjusted R^2': [Adj_r2, Adj_r2_2]
     },
    index = ['Paper Model Metrics', 'Our Model Metrics'])
 .style.set_caption("Models Metrics")
 .set_table_styles([{
     #Caption
     'selector': 'caption',
     'props': 'caption-side: top; font-size:1.3em;'}])
 .format(formatter={('RMSE'): lambda x: "{:,.4f}".format(x),
                          ('MAE'): lambda x: "{:,.4f}".format(x),
                          ('R^2'): lambda x: '{:,.2f}'.format(x),
                          ('Adjusted R^2'): lambda x: '{:,.2f}'.format(x)})
)
Out[23]:
Models Metrics
  RMSE MAE R^2 Adjusted R^2
Paper Model Metrics 0.0103 0.0066 0.89 0.87
Our Model Metrics 0.0089 0.0059 0.92 0.90

We get different results for the same algorithm, run on the same machine, when it is implemented in different languages such as R and Python.

We believe that small differences in the implementation of the underlying math libraries cause these differences in the resulting model and in the predictions it makes.

In fact, the RMSE reported in the paper is $0.013$, whereas the RMSE we computed ourselves in R was $0.012$ and in Python was $0.010$, as shown in the chunk above.

This tells us that the average deviation between the predicted and the actual CFR values is about 0.01.